{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "xGy4FpKqat4m",
    "colab_type": "code",
    "outputId": "85104072-315a-4ee1-9b79-cd6b639a7bbe",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 224.0
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E: Package 'python-software-properties' has no installation candidate\n",
      "Selecting previously unselected package google-drive-ocamlfuse.\n",
      "(Reading database ... 130912 files and directories currently installed.)\n",
      "Preparing to unpack .../google-drive-ocamlfuse_0.7.3-0ubuntu3~ubuntu18.04.1_amd64.deb ...\n",
      "Unpacking google-drive-ocamlfuse (0.7.3-0ubuntu3~ubuntu18.04.1) ...\n",
      "Setting up google-drive-ocamlfuse (0.7.3-0ubuntu3~ubuntu18.04.1) ...\n",
      "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n",
      "Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force\n",
      "··········\n",
      "Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force\n",
      "Please enter the verification code: Access token retrieved correctly.\n"
     ]
    }
   ],
   "source": [
    "# Install google-drive-ocamlfuse from its PPA so Google Drive can be mounted.\n",
    "# 'python-software-properties' is intentionally not requested: apt reports that it\n",
    "# has no installation candidate on this image, which makes the install command error out.\n",
    "!apt-get install -y -qq software-properties-common module-init-tools\n",
    "# NOTE: '2>&1 > /dev/null' discards stdout only; stderr still reaches the cell output.\n",
    "!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null\n",
    "!apt-get update -qq 2>&1 > /dev/null\n",
    "!apt-get -y install -qq google-drive-ocamlfuse fuse\n",
    "\n",
    "# Authenticate the Colab user and load the application-default credentials.\n",
    "from google.colab import auth\n",
    "auth.authenticate_user()\n",
    "from oauth2client.client import GoogleCredentials\n",
    "creds = GoogleCredentials.get_application_default()\n",
    "\n",
    "# The first run only prints the authorisation URL; the verification code is then\n",
    "# read without echo and piped into a second run.\n",
    "import getpass\n",
    "!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL\n",
    "vcode = getpass.getpass()\n",
    "!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "kH1i7-fha8J3",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Create the mount point (no error if it already exists) and mount Google Drive on it.\n",
    "!mkdir -p drive\n",
    "!google-drive-ocamlfuse drive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "4pPkfWfEpws2",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Clone the ASTER repository into ./aster.\n",
    "!git clone https://github.com/bgshih/aster.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "pKWHkEpipu8x",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# System packages; -y answers apt's confirmation prompt, which cannot be answered\n",
    "# reliably from a notebook cell.\n",
    "!sudo apt install -y cmake libcupti-dev\n",
    "# Python-side dependencies, installed into the user site-packages.\n",
    "!pip install --user protobuf tqdm numpy editdistance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "GZ9F7lHipsnL",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Run the build script inside aster/c_ops; '&&' skips it if the directory change fails.\n",
    "!cd aster/c_ops && sh build.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "QyloNwUYbe5A",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Compile every .proto file under aster/protos to Python, writing output relative to the current directory.\n",
    "!protoc aster/protos/*.proto --python_out=."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xohvFXU0H-7B",
    "colab_type": "text"
   },
   "source": [
    "## 制作训练数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "rcs30pTBdMhs",
    "colab_type": "code",
    "outputId": "929f57c7-04b6-4e05-fc8c-cae7e0c3b4c6",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34.0
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2081 examples created\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import io\n",
    "import copy\n",
    "import random\n",
    "import re\n",
    "import xml.etree.ElementTree as ET\n",
    "import sys\n",
    "import pandas as pd\n",
    "\n",
    "from PIL import Image, ImageFilter\n",
    "import tensorflow as tf\n",
    "\n",
    "# Script-level settings.\n",
    "# NOTE(review): data_dir, ignore_difficult and crop_margin are not referenced by any\n",
    "# other code in this cell -- confirm whether they are still needed.\n",
    "data_dir='./'\n",
    "ignore_difficult=True\n",
    "crop_margin=0.2\n",
    "# Passed to _random_lexicon() by create_ic03().\n",
    "lexicon_size = 50\n",
    "# Seed the module-level RNG that _random_lexicon() shuffles with.\n",
    "random.seed(1)\n",
    "\n",
    "class InputDataFields:\n",
    "  \"\"\"String constants naming the fields of an input record.\"\"\"\n",
    "\n",
    "  # Image data.\n",
    "  image = 'image'\n",
    "  original_image = 'original_image'\n",
    "\n",
    "  # Record identification.\n",
    "  key = 'key'\n",
    "  source_id = 'source_id'\n",
    "  filename = 'filename'\n",
    "\n",
    "  # Ground truth and lexicon.\n",
    "  groundtruth_text = 'groundtruth_text'\n",
    "  groundtruth_keypoints = 'groundtruth_keypoints'\n",
    "  lexicon = 'lexicon'\n",
    "\n",
    "\n",
    "class TfExampleFields:\n",
    "  \"\"\"Feature keys used inside the serialized tf.train.Example records.\"\"\"\n",
    "\n",
    "  # Encoded image and its encoding (named image_format to avoid clashing with `format`).\n",
    "  image_encoded = 'image/encoded'\n",
    "  image_format = 'image/format'\n",
    "\n",
    "  # Image identification.\n",
    "  filename = 'image/filename'\n",
    "  source_id = 'image/source_id'\n",
    "\n",
    "  # Image geometry and colour layout.\n",
    "  channels = 'image/channels'\n",
    "  colorspace = 'image/colorspace'\n",
    "  height = 'image/height'\n",
    "  width = 'image/width'\n",
    "\n",
    "  # Annotations.\n",
    "  transcript = 'image/transcript'\n",
    "  lexicon = 'image/lexicon'\n",
    "  keypoints = 'image/keypoints'\n",
    "  \n",
    "def int64_feature(value):\n",
    "  \"\"\"Wrap a single integer in a `tf.train.Feature` holding an Int64List.\"\"\"\n",
    "  int64_list = tf.train.Int64List(value=[value])\n",
    "  return tf.train.Feature(int64_list=int64_list)\n",
    "\n",
    "\n",
    "def int64_list_feature(value):\n",
    "  \"\"\"Wrap a sequence of integers in a `tf.train.Feature` holding an Int64List.\"\"\"\n",
    "  int64_list = tf.train.Int64List(value=value)\n",
    "  return tf.train.Feature(int64_list=int64_list)\n",
    "\n",
    "\n",
    "def bytes_feature(value):\n",
    "  \"\"\"Wrap a single bytes value in a `tf.train.Feature` holding a BytesList.\"\"\"\n",
    "  bytes_list = tf.train.BytesList(value=[value])\n",
    "  return tf.train.Feature(bytes_list=bytes_list)\n",
    "\n",
    "\n",
    "def bytes_list_feature(value):\n",
    "  \"\"\"Wrap a sequence of bytes values in a `tf.train.Feature` holding a BytesList.\"\"\"\n",
    "  bytes_list = tf.train.BytesList(value=value)\n",
    "  return tf.train.Feature(bytes_list=bytes_list)\n",
    "\n",
    "\n",
    "def float_list_feature(value):\n",
    "  \"\"\"Wrap a sequence of floats in a `tf.train.Feature` holding a FloatList.\"\"\"\n",
    "  float_list = tf.train.FloatList(value=value)\n",
    "  return tf.train.Feature(float_list=float_list)\n",
    "\n",
    "\n",
    "def read_examples_list(path):\n",
    "  \"\"\"Return the first space-separated token of every line in a file.\n",
    "\n",
    "  The file is opened with `tf.gfile.GFile`. Each line is stripped of\n",
    "  surrounding whitespace and split on single spaces; only the first token is\n",
    "  kept, so a line such as `xyz 3` yields `xyz`.\n",
    "\n",
    "  Args:\n",
    "    path: path to the examples list file.\n",
    "\n",
    "  Returns:\n",
    "    list of example identifiers (strings), one per line of the file.\n",
    "  \"\"\"\n",
    "  identifiers = []\n",
    "  with tf.gfile.GFile(path) as list_file:\n",
    "    for raw_line in list_file.readlines():\n",
    "      identifiers.append(raw_line.strip().split(' ')[0])\n",
    "  return identifiers\n",
    "\n",
    "\n",
    "def recursive_parse_xml_to_dict(xml):\n",
    "  \"\"\"Recursively convert an XML element into nested Python dicts.\n",
    "\n",
    "  An element without children becomes `{tag: text}`. Otherwise every child is\n",
    "  converted recursively and stored under its tag. Children tagged `object`\n",
    "  are collected into a list so repeated `object` siblings are all kept; for\n",
    "  any other tag a later sibling with the same tag overwrites the earlier one.\n",
    "\n",
    "  Args:\n",
    "    xml: an element exposing `tag`, `text`, `len()` and iteration over its\n",
    "      children, e.g. `xml.etree.ElementTree.Element`.\n",
    "\n",
    "  Returns:\n",
    "    A single-key dict `{xml.tag: contents}`.\n",
    "  \"\"\"\n",
    "  # Count children with len() rather than truth-testing the element: relying\n",
    "  # on Element truthiness is deprecated in ElementTree.\n",
    "  if len(xml) == 0:\n",
    "    return {xml.tag: xml.text}\n",
    "  result = {}\n",
    "  for child in xml:\n",
    "    child_result = recursive_parse_xml_to_dict(child)\n",
    "    if child.tag == 'object':\n",
    "      result.setdefault(child.tag, []).append(child_result[child.tag])\n",
    "    else:\n",
    "      result[child.tag] = child_result[child.tag]\n",
    "  return {xml.tag: result}\n",
    "\n",
    "def _random_lexicon(lexicon_list, groundtruth_text, lexicon_size):\n",
    "    \"\"\"Build a lexicon whose first entry is the ground-truth text.\n",
    "\n",
    "    A deep copy of `lexicon_list` is shuffled with the module-level `random`\n",
    "    RNG, cut down to its first `lexicon_size - 1` entries and prefixed with\n",
    "    `groundtruth_text`. `lexicon_list` itself is left untouched.\n",
    "    \"\"\"\n",
    "    candidates = copy.deepcopy(lexicon_list)\n",
    "    random.shuffle(candidates)\n",
    "    return [groundtruth_text] + candidates[:(lexicon_size - 1)]\n",
    "\n",
    "\n",
    "def _is_difficult(word):\n",
    "    \"\"\"Tell whether `word` should be treated as a difficult sample.\n",
    "\n",
    "    A word is difficult when it has fewer than 3 characters or contains\n",
    "    anything other than regex word characters (letters, digits, underscore).\n",
    "\n",
    "    Args:\n",
    "      word: the text to check; must be a `str`.\n",
    "\n",
    "    Returns:\n",
    "      bool: True if the word is difficult, False otherwise.\n",
    "    \"\"\"\n",
    "    assert isinstance(word, str)\n",
    "    # Raw string: a backslash-w inside a normal literal is an invalid escape sequence.\n",
    "    return len(word) < 3 or not re.match(r'^[\\w]+$', word)\n",
    "\n",
    "\n",
    "def _oversample_factor(label):\n",
    "    \"\"\"Return how many copies of a sample to write, based on its label text.\"\"\"\n",
    "    if '3' in label and '8' not in label:\n",
    "        return 4\n",
    "    if '5' in label and '6' not in label and '9' not in label:\n",
    "        return 2\n",
    "    return 1\n",
    "\n",
    "\n",
    "def create_ic03(output_path,\n",
    "                csv_path='./train_id_label.csv',\n",
    "                image_dir='/content/train_cptn_result',\n",
    "                start_row=38500):\n",
    "    \"\"\"Write a TFRecord of JPEG-encoded images and their transcripts.\n",
    "\n",
    "    Rows of `csv_path` from `start_row` onward are processed. Each row's\n",
    "    `name` column is an image file name under `image_dir` and its `label`\n",
    "    column is the text; the stripped, lower-cased label is stored as the\n",
    "    transcript. Rows whose image file is missing are reported and skipped.\n",
    "    Labels containing '3' but not '8' are written 4 times, labels containing\n",
    "    '5' but neither '6' nor '9' are written twice, all others once.\n",
    "\n",
    "    Args:\n",
    "      output_path: path of the TFRecord file to create.\n",
    "      csv_path: CSV file with `name` and `label` columns.\n",
    "      image_dir: directory holding the image files.\n",
    "      start_row: positional index of the first CSV row to use.\n",
    "\n",
    "    Prints the number of examples written.\n",
    "    \"\"\"\n",
    "    lexicon_list = []\n",
    "    count = 0\n",
    "    data = pd.read_csv(csv_path)[start_row:]\n",
    "\n",
    "    # The context manager closes the writer even if a row raises.\n",
    "    with tf.python_io.TFRecordWriter(output_path) as writer:\n",
    "        for ii, row in data.iterrows():\n",
    "            image_path = os.path.join(image_dir, row['name'])\n",
    "            if not os.path.exists(image_path):\n",
    "                print('%s does not exist' % image_path)\n",
    "                continue\n",
    "\n",
    "            label = row['label'].strip()\n",
    "            groundtruth_text = label.lower()\n",
    "\n",
    "            # Close the file handle as soon as the pixels are loaded. The record is\n",
    "            # tagged as 3-channel 'rgb' below, so make the pixel data match that.\n",
    "            with Image.open(image_path) as imag:\n",
    "                rgb_image = imag.convert('RGB')\n",
    "\n",
    "            # Encode once per row: every oversampled copy carries the same bytes.\n",
    "            im_buff = io.BytesIO()\n",
    "            rgb_image.save(im_buff, format='jpeg')\n",
    "            word_crop_jpeg = im_buff.getvalue()\n",
    "            crop_name = '{}:{}'.format(row['name'], ii)\n",
    "\n",
    "            for _ in range(_oversample_factor(label)):\n",
    "                lexicon = _random_lexicon(lexicon_list, groundtruth_text, lexicon_size)\n",
    "                feature = {\n",
    "                    TfExampleFields.image_encoded: bytes_feature(word_crop_jpeg),\n",
    "                    TfExampleFields.image_format: bytes_feature('jpeg'.encode('utf-8')),\n",
    "                    TfExampleFields.filename: bytes_feature(crop_name.encode('utf-8')),\n",
    "                    TfExampleFields.channels: int64_feature(3),\n",
    "                    TfExampleFields.colorspace: bytes_feature('rgb'.encode('utf-8')),\n",
    "                    TfExampleFields.transcript: bytes_feature(groundtruth_text.encode('utf-8')),\n",
    "                    TfExampleFields.lexicon: bytes_feature(('\\t'.join(lexicon)).encode('utf-8')),\n",
    "                }\n",
    "                example = tf.train.Example(features=tf.train.Features(feature=feature))\n",
    "                writer.write(example.SerializeToString())\n",
    "                count += 1\n",
    "\n",
    "    print('{} examples created'.format(count))\n",
    "\n",
    "\n",
    "# Write ./ic03_val.tfrecord when this cell is executed.\n",
    "if __name__ == '__main__':\n",
    "    create_ic03('./ic03_val.tfrecord')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LiNtp2g2rpEE",
    "colab_type": "text"
   },
   "source": [
    "### Colab 每 12 小时清空一次，需要重新 copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "tFy6hpnYNvLa",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Copy the aster/core Python files kept on Drive over the local checkout.\n",
    "# The commented-out commands copy the ic03 data files and the experiment logs.\n",
    "# !cp ./drive/my_gpu/tinymind/ic03* ./aster/data/\n",
    "!cp -r ./drive/my_gpu/tinymind/aster/core/*.py aster/core/\n",
    "# !mkdir -p aster/experiments/tinymind/log;mkdir -p aster/experiments/tinymind/config\n",
    "# !cp ./drive/my_gpu/tinymind/experiments/tinymind/log/* aster/experiments/tinymind/log/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Q4As3KJUdZXr",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Set LD_LIBRARY_PATH to the NVIDIA and CUDA 10.0 CUPTI library directories, then run aster/train_1.py.\n",
    "!export LD_LIBRARY_PATH=/usr/lib64-nvidia:/usr/local/cuda-10.0/extras/CUPTI/lib64;python aster/train_1.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6co9hGdbSOTI",
    "colab_type": "text"
   },
   "source": [
    "### 识别图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "EjuUTICIcnQT",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Make sure the experiment's log and config directories exist.\n",
    "!mkdir -p ./aster/experiments/tinymind/log/\n",
    "!mkdir -p ./aster/experiments/tinymind/config/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "sWrKKNi3SNd1",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Copy the step-110183 checkpoint files from Drive, dropping the step suffix so\n",
    "# they are named model.ckpt.* locally.\n",
    "!cp ./drive/my_gpu/tinymind/experiments/tinymind/log/model.ckpt-110183.data-00000-of-00001 ./aster/experiments/tinymind/log/model.ckpt.data-00000-of-00001\n",
    "!cp ./drive/my_gpu/tinymind/experiments/tinymind/log/model.ckpt-110183.index ./aster/experiments/tinymind/log/model.ckpt.index\n",
    "!cp ./drive/my_gpu/tinymind/experiments/tinymind/log/model.ckpt-110183.meta ./aster/experiments/tinymind/log/model.ckpt.meta\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "tcfqcifb8uwe",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Copy the local step-110183 checkpoint files to the log directory on Drive.\n",
    "!cp ./aster/experiments/tinymind/log/model.ckpt-110183* ./drive/my_gpu/tinymind/experiments/tinymind/log/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "u4kUGQfiWKXV",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# -p keeps this cell re-runnable: no error when aster_out already exists.\n",
    "!mkdir -p aster_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "IQw5mTvlr7Mr",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# Run the script aster/demo_1.py.\n",
    "!python aster/demo_1.py"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "ASERT.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
